import torch
from icecream import ic

from qa.table_bert.hm_table import *
from qa.tuta.reader import HMTReader
from qa.tuta.tokenizer import HMTTokenizer, PAD_ID
from qa.tuta.utils import UNZIPS


class TUTAPreprocessor(object):
    """ Preprocessor that converts cell matrix in hmt to tuta input batch."""
    def __init__(self, config):
        self.config = config
        self.reader = HMTReader(config)
        self.tokenizer = HMTTokenizer(config)
        # extra config
        self.config.total_node = sum(self.config.node_degree)
        self.config.vocab_size = len(self.tokenizer.vocab)
        self.config.default_pos = [self.config.total_node] * self.config.tree_depth
        self.config.default_format = [0.25, 0.25, 0., 0., 0., 0., 0., 0., 0., 1., 1.]
        # debug
        self.total = 0
        self.skip = 0

    def pipeline(self, contexts: List[List[str]], tables: List[HMTable]):
        """ Pipeline to preprocess env_contexts into TUTA inputs. """
        instances = self.read(tables)
        instances = self.tokenize_str(contexts, instances)
        instances = self.sample_and_prepare(tables, instances)
        instances = self.linearize(tables, instances)
        return instances

    def read(self, tables):
        new_instances = []
        for t in tables:
            new_instances.append(self.reader.result_from_table(t))
            # ic(new_instances[-1][1])
            # ic(t.header2level)
            # ic(t.index_name2id)
            # ic(new_instances[-1][-2])
            # ic(new_instances[-1][-1])
        """
        string_matrix: 
            [['exports', '2018', '2019', 'percent change'],
              ['european union', '2,838', '3,549', '25.1'],
              ['united states', '2,666', '2,867', '7.5'],
              ['canada', '1,277', '1,284', '0.5'],
              ['brazil', '722', '861', '19.3'],
              ['china', '202', '135', '-33.2'],
              ['mexico', '177', '234', '32.2'],
              ['japan', '4', '3', '-25'],
              ['south korea', '1', '1', '0'],
              ['philippines', '1', '1', '0'],
              ['other countries', '358', '398', '11.2'],
              ['total', '8,246', '9,335', '13.2']]
        position_lists:  (top_tree_coord, left_tree_coord), top header has -1 left coord, and vice versa. 
             ([[-1, -1, -1, -1],
               [-1, -1, -1, 0],
               [-1, -1, -1, 1],
               [-1, -1, -1, 2],
               [-1, -1, -1, -1],
               [-1, -1, -1, 0],
               [-1, -1, -1, 1],
               [-1, -1, -1, 2],
                ...
               [-1, -1, -1, 2]],
              [[-1, -1, -1, -1],
               [-1, -1, -1, -1],
               [-1, -1, -1, -1],
               [-1, -1, -1, -1],
               [-1, -1, -1, 0],
               [-1, -1, -1, 0],
               [-1, -1, -1, 0],
               [-1, -1, -1, 0],
               [-1, -1, -1, 1],
                ...
               [-1, -1, -1, 10]])
        header_info: (num_of_top_tree_levels, num_of_left_tree_levels)
            (2, 2)
        context:  (title)
            'table 1: total world pork exports'
        header_matrix:
            [[-1, 11, 12, 13],
            [0, -1, -1, -1],
            [1, -1, -1, -1],
            [2, -1, -1, -1],
            [3, -1, -1, -1],
            [4, -1, -1, -1],
            [5, -1, -1, -1],
            [6, -1, -1, -1],
            [7, -1, -1, -1],
            [8, -1, -1, -1],
            [9, -1, -1, -1],
            [10, -1, -1, -1]]
        level_matrix:
            [[-1, 3, 3, 3],
            [1, -1, -1, -1],
            [1, -1, -1, -1],
            [1, -1, -1, -1],
            [1, -1, -1, -1],
            [1, -1, -1, -1],
            [1, -1, -1, -1],
            [1, -1, -1, -1],
            [1, -1, -1, -1],
            [1, -1, -1, -1],
            [1, -1, -1, -1],
            [1, -1, -1, -1]]
        """
        return new_instances

    def tokenize_str(self, input_contexts, instances):
        new_instances = []
        i = 0
        for (string_matrix, position_lists, header_info, context, header_matrix, level_matrix, top_header_coord_set, left_header_coord_set) in instances:
            token_matrix, number_matrix = self.tokenizer.tokenize_string_matrix(string_matrix, add_separate=self.config.add_separate)
            if input_contexts[i] != '':  # if input question, use question as contexts
                context = input_contexts[i]
            new_instances.append((
                token_matrix,
                number_matrix,
                position_lists,
                header_info,
                context,
                header_matrix,
                level_matrix,
                top_header_coord_set,
                left_header_coord_set
            ))
            i += 1
        """
        'string matrix' is tokenized into 'token_matrix' and 'number_matrix'
        token_matrix:
            [[[102, 14338], [102, 2760], [102, 10476], [102, 3867, 2689]],
              [[102, 2647, 2586], [102, 1], [102, 1], [102, 1]],
              [[102, 2142, 2163], [102, 1], [102, 1], [102, 1]],
              [[102, 2710], [102, 1], [102, 1], [102, 1]],
              [[102, 4380], [102, 1], [102, 1], [102, 1]],
              [[102, 2859], [102, 16798], [102, 11502], [102, 1011, 1]],
              [[102, 3290], [102, 18118], [102, 22018], [102, 1]],
              [[102, 2900], [102, 1018], [102, 1017], [102, 1011, 2423]],
              [[102, 2148, 4420], [102, 1015], [102, 1015], [102, 1014]],
              [[102, 5137], [102, 1015], [102, 1015], [102, 1014]],
              [[102, 2060, 3032], [102, 1], [102, 1], [102, 1]],
              [[102, 2561], [102, 1], [102, 1], [102, 1]]]
        number_matrix:
            [[[(11, 11, 11, 11), (11, 11, 11, 11)],
               [(11, 11, 11, 11), (4, 0, 2, 8)],
               [(11, 11, 11, 11), (4, 0, 2, 9)],
               [(11, 11, 11, 11), (11, 11, 11, 11), (11, 11, 11, 11)]],
                ...
              [[(11, 11, 11, 11), (11, 11, 11, 11)],
               [(11, 11, 11, 11), (4, 0, 8, 6)],
               [(11, 11, 11, 11), (4, 0, 9, 5)],
               [(11, 11, 11, 11), (2, 1, 1, 2)]]]

        """
        return new_instances

    def sample_and_prepare(self, tables, instances):
        new_instances = []
        for table, instance in zip(tables, instances):
            token_matrix, number_matrix, position_lists, header_info, format_or_text, header_matrix, level_matrix, top_header_coord_set, left_header_coord_set = instance
            format_matrix, context = None, None
            context = format_or_text

            if self.config.hier_or_flat == "hier":
                header_rows, header_columns = header_info
                if (header_rows <= 1) and (header_columns <= 1):
                    continue
            elif self.config.hier_or_flat == "flat":
                header_rows, header_columns = header_info
                if (header_rows > 1) or (header_columns > 1):
                    continue

            sampling_matrix = self.tokenizer.sampling(
                token_matrix=token_matrix,
                number_matrix=number_matrix,
                header_info=header_info,
                header_matrix=header_matrix,
                max_disturb_num=self.config.max_disturb_num,
                disturb_prob=self.config.disturb_prob,
                clc_rate=self.config.clc_rate
            )
            # print(sampling_matrix)
            # print(number_matrix)
            results = self.tokenizer.create_table_seq(
                table=table,
                sampling_matrix=sampling_matrix,
                token_matrix=token_matrix,
                number_matrix=number_matrix,
                position_lists=position_lists,
                format_matrix=format_matrix,
                context=context,
                header_matrix=header_matrix,
                level_matrix=level_matrix,
                top_header_coord_set=top_header_coord_set,
                left_header_coord_set=left_header_coord_set,
                add_sep=self.config.add_separate
            )
            if (results is None) or (len(results[0]) > self.config.max_cell_num):
                ic(results)
                self.skip += 1
                continue
            # token_seq = [tok for cell in results[0] for tok in cell]
            # if len(token_seq) > self.config.max_seq_len:
            #     ic(len(token_seq))
            #     continue
            self.total += 1
            # ic(self.skip, self.total)
            new_instances.append(results)
        """
        token_list:
            [[101, 2023, 2003, 3160, 2028, 1012],
                          [102, 14338],
                          [102, 2760],
                          [102, 10476],
                            ...
                          [102, 2060, 3032],
                          [102, 1],
                          [102, 2561],
                          [102, 1]]
        num_list:
            [[(11, 11, 11, 11),(11, 11, 11, 11),(11, 11, 11, 11),(11, 11, 11, 11),(11, 11, 11, 11), (11, 11, 11, 11)],
                  [(11, 11, 11, 11), (11, 11, 11, 11)],
                  [(11, 11, 11, 11), (4, 0, 2, 8)],
                  [(11, 11, 11, 11), (4, 0, 2, 9)],
                    ...
                  [(11, 11, 11, 11), (11, 11, 11, 11), (11, 11, 11, 11)],
                  [(11, 11, 11, 11), (3, 0, 3, 8)],
                  [(11, 11, 11, 11), (11, 11, 11, 11)],
                  [(11, 11, 11, 11), (4, 0, 8, 6)]]
        pos_list:
            [(256, 256, [-1, -1, -1, -1], [-1, -1, -1, -1]),
                  (0, 0, [-1, -1, -1, -1], [-1, -1, -1, -1]),
                  (0, 1, [-1, -1, -1, 0], [-1, -1, -1, -1]),
                  (0, 2, [-1, -1, -1, 1], [-1, -1, -1, -1]),
                    ...
                  (10, 0, [-1, -1, -1, -1], [-1, -1, -1, 9]),
                  (10, 1, [-1, -1, -1, 0], [-1, -1, -1, 9]),
                  (11, 0, [-1, -1, -1, -1], [-1, -1, -1, 10]),
                  (11, 1, [-1, -1, -1, 0], [-1, -1, -1, 10])]
        fmt_list:
            [[0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
                  [0.0625, 0.0625, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
                  [0.0625, 0.0625, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
                  [0.0625, 0.0625, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
                    ...
                  [0.0625, 0.0625, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
                  [0.0625, 0.0625, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
                  [0.0625, 0.0625, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
                  [0.0625, 0.0625, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0]]
        
        ind_list:
            [[-1, -2, -2, -2, -2, -2],
                  [1, 2],
                  [3, 4],
                  [5, 6],
                    ...
                  [51, 52, 52],
                  [53, 54],
                  [55, 56],
                  [57, 58]]
        """
        return new_instances

    def linearize(self, tables, instances):
        batch_max_seq_len = 0
        all_token_id, all_num_mag, all_num_pre, all_num_top, all_num_low = [], [], [], [], []
        all_token_order, all_pos_row, all_pos_col, all_pos_top, all_pos_left = [], [], [], [], []
        all_format_vec, all_indicator = [], []

        all_context_mask, all_context_token_indices = [], []
        all_header_mask, all_header_token_indices = [], []
        all_level_mask, all_level_token_indices = [], []
        for table, instance in zip(tables, instances):
            tok_list, num_list, pos_list, fmt_list, cell_ind, \
            context_mask, context_token_indices, header_mask, header_token_indices, level_mask, level_token_indices = instance
            token_id, num_mag, num_pre, num_top, num_low = [], [], [], [], []
            token_order, pos_row, pos_col, pos_top, pos_left = [], [], [], [], []
            format_vec, indicator = [], []

            cell_num = len(tok_list)
            for icell in range(cell_num):  # cell_level
                tokens = tok_list[icell]
                cell_len = len(tokens)
                token_id.extend(tokens)
                token_order.extend([ii for ii in range(cell_len)])

                num_feats = num_list[icell]
                num_mag.extend([f[0] for f in num_feats])
                num_pre.extend([f[1] for f in num_feats])
                num_top.extend([f[2] for f in num_feats])
                num_low.extend([f[3] for f in num_feats])

                row, col, ttop, tleft = pos_list[icell]
                pos_row.extend([row for _ in range(cell_len)])
                pos_col.extend([col for _ in range(cell_len)])  # FIXME: node_degree is not split.
                entire_top = UNZIPS[self.config.target](ttop, self.config.node_degree, self.config.total_node)
                pos_top.extend([entire_top for _ in range(cell_len)])
                entire_left = UNZIPS[self.config.target](tleft, self.config.node_degree, self.config.total_node)
                pos_left.extend([entire_left for _ in range(cell_len)])

                format_vec.extend([fmt_list[icell] for _ in range(cell_len)])
                indicator.extend(cell_ind[icell])



            seq_len = len(token_id)
            # if seq_len > self.config.max_seq_len:  # stop if exceed seq_len bound
            #     continue
            batch_max_seq_len = max(batch_max_seq_len, seq_len)

            # append to overall instance set
            all_token_id.append(token_id)
            all_num_mag.append(num_mag)
            all_num_pre.append(num_pre)
            all_num_top.append(num_top)
            all_num_low.append(num_low)
            all_token_order.append(token_order)
            all_pos_row.append(pos_row)
            all_pos_col.append(pos_col)
            all_pos_top.append(pos_top)
            all_pos_left.append(pos_left)
            all_format_vec.append(format_vec)
            all_indicator.append(indicator)

            all_context_mask.append(context_mask)
            all_context_token_indices.append(context_token_indices)
            all_header_mask.append(header_mask)
            all_header_token_indices.append(header_token_indices)
            all_level_mask.append(level_mask)
            all_level_token_indices.append(level_token_indices)

            try:
                assert len(token_id) == len(num_mag) == len(token_order) == len(pos_row) == len(format_vec) \
                   == len(indicator) == len(header_token_indices) == len(level_token_indices)
            except:
                ic(len(token_id), len(num_mag), len(token_order), len(pos_row), len(format_vec), len(indicator),
                   len(header_token_indices), len(level_token_indices))
            # print(f"context_mask: {context_mask}")
            # print(f"context_token_indices: {context_token_indices}")
            # print(f"header_mask: {header_mask}")
            # print(f"header_token_indices: {header_token_indices}")
            # print(f"level_mask: {level_mask}")
            # print(f"level_token_indices: {level_token_indices}")
            # print(table.header2id)
            # print(table.index_name2id)
            # tok_str_list = [self.tokenizer.convert_ids_to_tokens(tok_ids) for tok_ids in tok_list]
            # print(f"token_list\n: {tok_str_list}")
            # print(table.matrix_dict['Texts'])
            # print()
            """
            NOT padded yet.
            token_id:
                [101, 2023, 2003, 3160, 2028, 1012, 102, 14338, 102, 2760, 102, 10476, 102, 3867, 2689, 102, 2647, 2586, 102, 1, 102, 1, 102, 1, 102, 2142, 2163, 102, 1, 102, 2710, 102, 1, 102, 4380, 102, 1, 102, 1, 102, 2859, 102, 16798, 102, 3290, 102, 18118, 102, 2900, 102, 1018, 102, 2148, 4420, 102, 1015, 102, 5137, 102, 1015, 102, 2060, 3032, 102, 1, 102, 2561, 102, 1]
            num_mag(num_pre, num_top, num_low are similar)
                [11, 11, 11, 11, 11, 11, 11, 11, 11, 4, 11, 4, 11, 11, 11, 11, 11, 11, 11, 4, 11, 4, 11, 2, 11, 11, 11, 11, 4, 11, 11, 11, 4, 11, 11, 11, 3, 11, 2, 11, 11, 11, 3, 11, 11, 11, 3, 11, 11, 11, 1, 11, 11, 11, 11, 1, 11, 11, 11, 1, 11, 11, 11, 11, 3, 11, 11, 11, 4]
            token_order:
                [0, 1, 2, 3, 4, 5, 0, 1, 0, 1, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 0, 1, 0, 1, 0, 1, 2, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 0, 1, 0, 1, 0, 1, 0, 1, 2, 0, 1, 0, 1, 0, 1]
            pos_row:
                [256, 256, 256, 256, 256, 256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 8, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11]
            pos_col:
                [256, 256, 256, 256, 256, 256, 0, 0, 1, 1, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 2, 2, 3, 3, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 3, 3, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1]
            pos_top:
                [[384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 129], [384, 384, 384, 129], [384, 384, 384, 130], [384, 384, 384, 130], [384, 384, 384, 130], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 129], [384, 384, 384, 129], [384, 384, 384, 130], [384, 384, 384, 130], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 130], [384, 384, 384, 130], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 128], [384, 384, 384, 128]]
            pos_left:
                [[384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 384], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 128], [384, 384, 384, 129], [384, 384, 384, 129], [384, 384, 384, 129], [384, 384, 384, 129], [384, 384, 384, 129], [384, 384, 384, 130], [384, 384, 384, 130], [384, 384, 384, 130], [384, 384, 384, 130], [384, 384, 384, 131], [384, 384, 384, 131], [384, 384, 384, 131], [384, 384, 384, 131], [384, 384, 384, 131], [384, 384, 384, 131], [384, 384, 384, 132], [384, 384, 384, 132], [384, 384, 384, 132], [384, 384, 384, 132], [384, 384, 384, 133], [384, 384, 384, 133], [384, 384, 384, 133], [384, 384, 384, 133], [384, 384, 384, 134], [384, 384, 384, 134], [384, 384, 384, 134], [384, 384, 384, 134], [384, 384, 384, 135], [384, 384, 384, 135], [384, 384, 384, 135], [384, 384, 384, 135], [384, 384, 384, 135], [384, 384, 384, 136], [384, 384, 384, 136], [384, 384, 384, 136], [384, 384, 384, 136], [384, 384, 384, 137], [384, 384, 384, 137], [384, 384, 384, 137], [384, 384, 384, 137], [384, 384, 384, 137], [384, 384, 384, 138], [384, 384, 384, 138], [384, 384, 384, 138], [384, 384, 384, 138]]
            format_vec:
                [[0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], [0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], [0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], [0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], [0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], [0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], [0.0625, 0.0625, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], [0.0625, 0.0625, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], [0.0625, 0.0625, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], [0.0625, 0.0625, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], ...]
            indicator:
                [-1, -2, -2, -2, -2, -2, 1, 2, 3, 4, 5, 6, 7, 8, 8, 9, 10, 10, 11, 12, 13, 14, 15, 16, 17, 18, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 44, 45, 46, 47, 48, 49, 50, 51, 52, 52, 53, 54, 55, 56, 57, 58]
            """

        # pad things to batch_max_seq_len
        max_context_len = max(len(instance[5]) for instance in instances)
        max_level_num = max(len(table.index_name2id) for table in tables)
        max_header_num = max(len(table.header2id) for table in tables)
        batch_max_seq_len = ((batch_max_seq_len + 7) // 8) * 8
        for isample in range(len(all_token_id)):
            all_token_id[isample].extend([PAD_ID] * (batch_max_seq_len - len(all_token_id[isample])))
            all_num_mag[isample].extend([self.config.magnitude_size + 1] * (batch_max_seq_len - len(all_num_mag[isample])))
            all_num_pre[isample].extend([self.config.precision_size + 1] * (batch_max_seq_len - len(all_num_pre[isample])))
            all_num_top[isample].extend([self.config.top_digit_size + 1] * (batch_max_seq_len - len(all_num_top[isample])))
            all_num_low[isample].extend([self.config.low_digit_size + 1] * (batch_max_seq_len - len(all_num_low[isample])))

            all_token_order[isample].extend([0] * (batch_max_seq_len - len(all_token_order[isample])))
            all_pos_row[isample].extend([self.config.row_size] * (batch_max_seq_len - len(all_pos_row[isample])))
            all_pos_col[isample].extend([self.config.column_size] * (batch_max_seq_len - len(all_pos_col[isample])))
            all_pos_top[isample].extend([self.config.default_pos] * (batch_max_seq_len - len(all_pos_top[isample])))
            all_pos_left[isample].extend([self.config.default_pos] * (batch_max_seq_len - len(all_pos_left[isample])))
            all_format_vec[isample].extend([self.config.default_format] * (batch_max_seq_len - len(all_format_vec[isample])))
            all_indicator[isample].extend([0] * (batch_max_seq_len - len(all_indicator[isample])))

            # mask for hmt
            all_context_mask[isample].extend([0] * (max_context_len - len(all_context_mask[isample])))
            all_context_token_indices[isample].extend([0] * (max_context_len - len(all_context_token_indices[isample])))
            all_header_mask[isample].extend([0] * (max_header_num - len(all_header_mask[isample])))
            all_header_token_indices[isample].extend([-1] * (batch_max_seq_len - len(all_header_token_indices[isample])))
            all_level_mask[isample].extend([0] * (max_level_num - len(all_level_mask[isample])))
            all_level_token_indices[isample].extend([-1] * (batch_max_seq_len - len(all_level_token_indices[isample])))
            # truncate
            all_token_id[isample] = all_token_id[isample][: self.config.max_seq_len]
            all_num_mag[isample] = all_num_mag[isample][: self.config.max_seq_len]
            all_num_pre[isample] = all_num_pre[isample][: self.config.max_seq_len]
            all_num_top[isample] = all_num_top[isample][: self.config.max_seq_len]
            all_num_low[isample] = all_num_low[isample][: self.config.max_seq_len]
            all_token_order[isample] = all_token_order[isample][: self.config.max_seq_len]
            all_pos_row[isample] = all_pos_row[isample][: self.config.max_seq_len]
            all_pos_col[isample] = all_pos_col[isample][: self.config.max_seq_len]
            all_pos_top[isample] = all_pos_top[isample][: self.config.max_seq_len]
            all_pos_left[isample] = all_pos_left[isample][: self.config.max_seq_len]
            all_format_vec[isample] = all_format_vec[isample][: self.config.max_seq_len]
            all_indicator[isample] = all_indicator[isample][: self.config.max_seq_len]
            all_header_token_indices[isample] = all_header_token_indices[isample][: self.config.max_seq_len]
            all_level_token_indices[isample] = all_level_token_indices[isample][: self.config.max_seq_len]
        # print()
        # for h in all_header_token_indices:
        #     print(len(h))
        header_token_indices = torch.LongTensor(all_header_token_indices)
        header_token_indices[header_token_indices == -1] = max_header_num
        level_token_indices = torch.LongTensor(all_level_token_indices)
        level_token_indices[level_token_indices == -1] = max_level_num

        # bert position ids
        all_position_ids, all_segment_ids = [], []
        for isample in range(len(all_token_id)):
            max_seq_len = min(batch_max_seq_len, self.config.max_seq_len)
            all_position_ids.append(list(range(max_seq_len)))
            num_q_tokens = max(all_context_token_indices[isample]) + 1
            segment_ids = [0] * num_q_tokens + [1] * (max_seq_len - num_q_tokens)
            all_segment_ids.append(segment_ids)

        return (
            torch.LongTensor(all_token_id),  # b * max_seq_len * 1, features together: b * max_seq_len * 12
            torch.LongTensor(all_num_mag),
            torch.LongTensor(all_num_pre),
            torch.LongTensor(all_num_top),
            torch.LongTensor(all_num_low),
            torch.LongTensor(all_token_order),
            torch.LongTensor(all_pos_row),
            torch.LongTensor(all_pos_col),
            torch.LongTensor(all_pos_top),
            torch.LongTensor(all_pos_left),
            torch.FloatTensor(all_format_vec),
            torch.LongTensor(all_indicator),
            torch.LongTensor(all_position_ids),
            torch.LongTensor(all_segment_ids)
        ), dict(
            context_token_mask=torch.LongTensor(all_context_mask),
            context_token_indices=torch.LongTensor(all_context_token_indices),
            header_mask=torch.LongTensor(all_header_mask),
            header_token_indices=header_token_indices,
            level_mask=torch.LongTensor(all_level_mask),
            level_token_indices=level_token_indices
        )